import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import pathlib

# RGI first-order region numbers 1..19; the plotting loop makes one figure per region.
regions = np.arange(1, 20)
# 1-based mass-balance forcing ids; id k is read from row k-1 of the chains table.
mb_forcing_ids = np.arange(1, 81)
# Participant names; each is also the directory and file-name prefix of its NetCDF output.
participants = ['Compagno', 'Huss', 'Kraaijenbrink', 'GLIMB']

# Forcing-chain table, read without a header row so columns are labelled 0, 1, 2.
# The plotting loop reads column 0 as the GCM name, column 1 as the SSP label and
# column 2 as the first year of a 20-year period.
data_excel = pd.read_excel(io="GMIP3-chains.xlsx", header=None)
plt.close(fig='all')  # discard any figures left over from a previous run

# Length (number of time steps plotted on the x-axis) of each region's NetCDF
# files, according to the GlacierMIP3 GitHub page. The row index is the RGI
# region number; row 0 is unused and stays 0. Shape is (20, 1).
netcdf_length = np.full((20, 1), 2000.0)
netcdf_length[0] = 0
netcdf_length[[1, 3, 4, 5, 6, 7, 9, 17, 19]] = 5000  # regions with 5000-long files

# Helpers for the plotting loop below: file naming quirks, NetCDF loading with
# unit conversion, and per-panel axis formatting.


def _build_filename(participant, region, start_year, gcm, ssp):
    """Return the relative NetCDF path for one participant/region/forcing run.

    Applies the participant-specific naming quirks used by this script:
    Compagno's files spell 'ipsl-cm6a-lr' as 'psl_cm6a-lr', and runs whose
    period starts before 2020 use a historical label instead of the chains
    table's SSP ('historical' for Kraaijenbrink; 'hist' for Compagno, GLIMB
    and Rounce).

    Args:
        participant: participant name, used as directory and file prefix.
        region: RGI region number, zero-padded to two digits in the name.
        start_year: first year of the 20-year period (chains table column 2).
        gcm: GCM label (chains table column 0).
        ssp: SSP label (chains table column 1).
    """
    end_year = start_year + 19
    period = f'{start_year}-{end_year}'
    if gcm == 'ipsl-cm6a-lr' and participant == 'Compagno':
        gcm = 'psl_cm6a-lr'  # no 'i' and '_' instead of '-'
    if start_year < 2020:  # historical run
        if participant == 'Kraaijenbrink':
            ssp = 'historical'
        elif participant in ('Compagno', 'GLIMB', 'Rounce'):
            ssp = 'hist'
    return f'{participant}/{participant}_rgi{region:02d}_sum_{period}_{gcm}_{ssp}.nc'


def _load_volume_area(filename, participant):
    """Read one NetCDF file and return (volume in km^3, area in km^2) as NumPy arrays.

    The dataset is opened in a ``with`` block so its file handle is released
    after the values have been loaded into memory.
    Compagno's 'volume_m3' variable is used unscaled because, per this script's
    original comment, it is already in km^3 (NOTE(review): confirm against the
    files). All other participants' volumes are divided by 1e9 (m^3 -> km^3).
    """
    with xr.open_dataset(filename) as ds:
        volume = ds['volume_m3'].values  # .values loads data before the file closes
        area = ds['area_m2'].values / 1e6  # m^2 -> km^2
    if participant != 'Compagno':
        volume = volume / 1e9  # m^3 -> km^3
    return volume, area


def _format_panel(n_participants, index, x_max, y_start, title, ylabel):
    """Activate subplot `index` of the 2-row grid and set its limits and labels.

    The y-axis upper limit is twice `y_start`. When `y_start` is None (no file
    was found for this participant), the y-limits are left on autoscale rather
    than borrowing another participant's value.
    """
    plt.subplot(2, n_participants, index)
    ax = plt.gca()
    ax.set_xlim(0, x_max)
    if y_start is not None:
        ax.set_ylim(0, y_start * 2)
    plt.title(title)
    plt.ylabel(ylabel)


plt.rcParams.update({'font.size': 8})  # same for every figure: set once
n_participants = len(participants)

for region in regions:  # one figure per RGI region
    plt.figure(region, figsize=(20, 10))
    plt.subplots_adjust(left=.03, bottom=.03, right=.97, top=.97)
    # Rows of netcdf_length are 1-element arrays; .item() gives a plain scalar.
    x_max = netcdf_length[region].item()

    # Each participant owns one column of the 2 x n grid:
    # top row = volume, bottom row = area.
    for column, participant in enumerate(participants, start=1):
        # Reset per participant so limits never leak in from a previous one.
        vol_start = None
        area_start = None

        for forcing_id in mb_forcing_ids:
            row = data_excel.iloc[forcing_id - 1]  # ids are 1-based, rows 0-based
            filename = _build_filename(participant, region,
                                       start_year=row[2], gcm=row[0], ssp=row[1])
            print(filename)
            if not pathlib.Path(filename).exists():
                continue

            volume, area = _load_volume_area(filename, participant)
            # Initial values of the last file read set this participant's y-limits.
            vol_start = float(volume[0])
            area_start = float(area[0])

            plt.subplot(2, n_participants, column)
            plt.plot(np.arange(len(volume)), volume)
            plt.subplot(2, n_participants, n_participants + column)
            plt.plot(np.arange(len(area)), area)

        _format_panel(n_participants, column, x_max, vol_start,
                      participant, 'Volume (km^3)')
        _format_panel(n_participants, n_participants + column, x_max,
                      area_start, participant, 'Area (km^2)')